import os
import random
import shutil


def split(filelist, test_num):
    """Randomly partition *filelist* into ``(train, test)`` lists.

    The last ``test_num`` items of a shuffled copy form the test list and the
    remaining items form the train list. If ``test_num`` is at least
    ``len(filelist)``, every item goes to test; if it is 0 (or negative),
    every item goes to train.

    Args:
        filelist: Sequence of items (e.g. file names) to split. It is not
            modified.
        test_num: Number of items to hold out for the test list.

    Returns:
        A ``(train, test)`` tuple of new lists.
    """
    shuffled = list(filelist)
    # Shuffle a copy so the caller's list is left untouched.
    random.shuffle(shuffled)
    # Compute the cut explicitly: slicing with ``-test_num`` breaks for
    # test_num == 0, where ``[:-0]`` is empty and ``[-0:]`` is everything.
    cut = max(len(shuffled) - max(test_num, 0), 0)
    return shuffled[:cut], shuffled[cut:]


src_root = 'oregon_wildlife_original'
dst_root = 'oregon_wildlife'

# Per-class split sizes: TEST_PER_CLASS images are held out first, then at
# most TRAIN_PER_CLASS of the remaining images are kept for training.
TEST_PER_CLASS = 50
TRAIN_PER_CLASS = 300

# Sorted so classes are processed (and reported) in a stable order;
# os.listdir order is arbitrary.
classes = sorted(os.listdir(src_root))


def _reset_dir(path):
    """Delete *path* if it is an existing directory, then recreate it empty."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _copy_files(names, src_dir, dst_dir):
    """Copy each file in *names* from *src_dir* into *dst_dir*, keeping its name."""
    for name in names:
        shutil.copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))


for each in classes:
    src_path = os.path.join(src_root, each)
    if not os.path.isdir(src_path):
        continue
    dst_train_path = os.path.join(dst_root, 'train', each)
    dst_test_path = os.path.join(dst_root, 'test', each)
    # Start from empty class directories. With exist_ok=True alone, files left
    # by a previous run (with a different random split) would remain, and the
    # same image could end up in both train and test.
    _reset_dir(dst_train_path)
    _reset_dir(dst_test_path)
    # Only regular files are copied; a subdirectory would make copyfile fail.
    imgs = sorted(
        name for name in os.listdir(src_path)
        if os.path.isfile(os.path.join(src_path, name))
    )
    train_imgs, test_imgs = split(imgs, TEST_PER_CLASS)
    train_imgs = train_imgs[:TRAIN_PER_CLASS]
    print(f'class {each}: {len(train_imgs)} train , {len(test_imgs)} test')
    _copy_files(train_imgs, src_path, dst_train_path)
    _copy_files(test_imgs, src_path, dst_test_path)
